#include<iostream>
#include<stdio.h>
#include<cstring>
#include<algorithm>
#include<cmath>
using namespace std;
// Prime modulus applied to the printed answer.
static const long long MOD=998244353;

// Returns (rows*cols - 1) mod MOD as a value in [0, MOD).
//
// Each factor is reduced modulo MOD before the multiplication, so the
// intermediate product stays below MOD*MOD (< 2^60) and cannot overflow a
// signed 64-bit integer, no matter how large the inputs are.
// Assumes rows and cols are non-negative.
static long long product_minus_one_mod(long long rows,long long cols)
{
	const long long a=rows%MOD;
	const long long b=cols%MOD;
	// When a*b is a multiple of MOD the subtraction would yield -1;
	// adding MOD first keeps the result non-negative.
	return (a*b%MOD+MOD-1)%MOD;
}

// Reads two integers n and m from "bpmp.in" and writes (n*m - 1) mod MOD
// to "bpmp.out".
int main()
{
	// Redirect the standard streams to the task's input/output files.
	freopen("bpmp.in","r",stdin);
	freopen("bpmp.out","w",stdout);

	long long n=0,m=0;
	std::cin>>n>>m;

	// m-1+(n-1)*m and n-1+(m-1)*n both simplify to n*m-1, so no branch on
	// the relative order of n and m is required.
	std::cout<<product_minus_one_mod(n,m)<<'\n';
	return 0;
}
